import json
import os
from rich.console import Console
from rich.panel import Panel
from rich.markdown import Markdown
from rich.theme import Theme
from rich.live import Live
from rich.text import Text
from rich.console import Console, Group
import sys
from prompt_toolkit import HTML
from prompt_toolkit import PromptSession
from prompt_toolkit import print_formatted_text
from prompt_toolkit.formatted_text import FormattedText
import subprocess
import shlex
from app.utils.cut_history import cut_history

class PentestAgent:
    """
    Interactive agent that runs a thought -> command -> observation loop for pentest tasks.

    The conversation is persisted as a JSON list of ``{"role": ..., "content": ...}``
    messages in ``<data_dir>/agent_<task_id>.json``. Every step reloads the list from
    disk, appends to it and writes it back.
    """
    def __init__(self, client, data_dir="./logs/", task_id=None):
        """
        Args:
            client: Object whose ``generate(messages=...)`` returns an iterable of text
                chunks (they are concatenated as strings below).
            data_dir: Directory holding the history file; created if it does not exist.
            task_id: Identifier embedded in the history file name.
        """
        self.client = client
        self.data_dir = data_dir
        self.task_id = task_id

        # Blank out Rich's default markdown/repr styles so streamed model output is
        # rendered without extra colouring.
        custom_theme = Theme({
            "markdown": "",
            "markdown.heading": "",
            "markdown.code": "",
            "markdown.pre": "",
            "markdown.link": "",
            "markdown.list": "",
            "markdown.strong": "",
            "markdown.emphasis": "",
            "markdown.block_quote": "",
            "repr.number": "",
            "repr.string": "",
            "repr.string_quote": "",
            "markdown.table": "",
            "markdown.table.header": "",
            "markdown.table.row": "",
            "markdown.table.cell": "",
            "markdown.hrule": ""
        })
        self.console = Console(theme=custom_theme)
        # exist_ok avoids the race between an exists() check and makedirs().
        os.makedirs(data_dir, exist_ok=True)

    @staticmethod
    def _titled_panel(title, body, style):
        """Build the '<title>:' heading followed by a bordered panel in the given colour."""
        return Group(
            Text(f"\n{title}: \n", style=f"bold {style}", end=""),
            Panel(body, expand=False, border_style=style),
        )

    def start(self):
        """
        Run the interactive session until the user exits.

        For each task, repeatedly ask the model for a thought, extract and run a command
        from it, show the result and record it. The inner loop ends when the model gives
        no executable command, or when the user presses Ctrl+C. Either way the user is
        prompted for the next task.
        """
        self.console.print(Text("────────────────────────────────────────────────────────────────\n\n", style="magenta"), end="")
        self.console.print(Text("Enter your message (single line or use ''' for multiple lines).\n", style="blue"))

        self.console.print(Group(
            Text("Pentest Muse: ", style="bold green", end=""),
            Text("Hi! What task do you want me to perform?"),
        ))

        while True:
            task = self.get_task()
            if not task:
                sys.exit()

            self.set_task(task)

            # Thought -> action -> observation loop.
            while True:
                try:
                    thought = self.generate_thought()
                    if thought is None:
                        # Generation was interrupted by the user.
                        break

                    action, status = self.determine_next_action(thought)
                    if status == 'stop':
                        break

                    execution_response = self.execute_action(action)

                    if execution_response['output']:
                        self.console.print(self._titled_panel("System Output", execution_response['output'], "magenta"))
                    if execution_response['error']:
                        self.console.print(self._titled_panel("System Error", execution_response['error'], "red"))

                    self.save_execution_result(execution_response)
                except KeyboardInterrupt:
                    # Ctrl+C stops the current task and returns to the input prompt.
                    break

    def get_history_file_path(self):
        """Return the path of this task's JSON history file."""
        return os.path.join(self.data_dir, f"agent_{self.task_id}.json")

    def load_history(self):
        """Return the stored message list, or an empty list when no history file exists."""
        history_file = self.get_history_file_path()
        if not os.path.exists(history_file):
            return []

        with open(history_file, "r", encoding="utf-8") as file:
            return json.load(file)

    def save_history(self, history):
        """Overwrite the history file with *history* (pretty-printed JSON)."""
        history_file = self.get_history_file_path()
        with open(history_file, "w", encoding="utf-8") as file:
            json.dump(history, file, indent=4)

    def set_task(self, task):
        """
        Record *task* as the current task and append it to the history as a user message.
        """
        self.task = task
        history = self.load_history()
        history.append({"role": "user", "content": task})
        self.save_history(history)

    def get_task(self):
        """
        Read a task from the user with prompt_toolkit.

        A line containing only ``'''`` toggles multi-line mode; ``exit`` quits; blank lines
        are ignored. Ctrl+C restarts the prompt if something was typed, otherwise it quits.

        Returns:
            The entered text (lines joined with newlines), or None if reading input failed.
        """
        session = PromptSession()

        self.console.print(Text("\n"), end="")

        # Initialised before the try so the KeyboardInterrupt handler can always read it.
        lines = []
        try:
            multiline = False
            while True:
                line = session.prompt(HTML('> ') if not multiline else "")
                stripped = line.strip()
                if stripped == "'''":
                    multiline = not multiline
                elif stripped == "exit":
                    sys.exit()
                elif stripped == "":
                    # Keep prompting; blank lines never terminate input.
                    continue
                else:
                    lines.append(line)
                if not multiline:
                    break

            return "\n".join(lines)
        except KeyboardInterrupt:
            if lines:
                return self.get_task()
            sys.exit()
        except Exception as e:
            print_formatted_text(FormattedText([("bold red", f"Error occurred while getting user input: {e}")]))
            return None

    def _stream_completion(self, messages, title, style, placeholder):
        """
        Stream a completion from the client into a live-updating titled panel.

        Returns:
            ``(content, interrupted)``: the text accumulated so far, and whether the user
            pressed Ctrl+C before the stream finished.
        """
        content = ''
        with Live(console=self.console, refresh_per_second=10) as live:
            live.update(self._titled_panel(title, placeholder, style))
            try:
                for chunk in self.client.generate(messages=messages):
                    content += chunk
                    live.update(self._titled_panel(title, Markdown(content), style))
            except KeyboardInterrupt:
                return content, True
        return content, False

    @staticmethod
    def _keep_last_system_message(history):
        """
        Return *history* with every system message dropped except the most recent one.

        Earlier command outputs (stored as system messages) can be very long, so only the
        latest one is sent back to the model. Relative order is preserved.
        """
        kept = []
        seen_system = False
        for message in reversed(history):
            if message["role"] == "system":
                if seen_system:
                    continue
                seen_system = True
            kept.append(message)
        kept.reverse()
        return kept

    def generate_thought(self):
        """
        Ask the model for the next step, streaming the reply to the console.

        Any non-empty reply, including a partial one when interrupted, is saved to the
        history as an assistant message.

        Returns:
            The reply text, or None if the user interrupted generation.
        """
        history = self.load_history()

        from app.prompts.prompts import PENTEST_THOUGHT_PROMPT

        filtered_history = cut_history(self._keep_last_system_message(history), length=128000)
        messages = [{"role": "system", "content": PENTEST_THOUGHT_PROMPT}] + filtered_history

        content, interrupted = self._stream_completion(messages, "Pentest Muse", "green", 'Thinking...')

        # Skip empty replies: an empty assistant message adds nothing to the context.
        if content:
            history.append({"role": "assistant", "content": content})
            self.save_history(history)

        return None if interrupted else content

    def determine_next_action(self, thought):
        """
        Decide whether *thought* leads to a command to run.

        Returns:
            ``(command, 'continue')`` when a non-blank command was extracted, otherwise
            ``(None, 'stop')``. This includes the case where the user interrupted extraction.
        """
        if not self.include_command(thought):
            return None, 'stop'

        next_action = self.extract_command(thought)
        # extract_command returns None when interrupted; never execute a missing/blank command.
        if not next_action or not next_action.strip():
            return None, 'stop'
        return next_action, 'continue'

    def extract_command(self, thought):
        """
        Use the model to turn *thought* into a command line, streaming it to the console.

        Returns:
            The extracted command text, or None if the user interrupted the stream.
        """
        from app.prompts.prompts import THOUGHT_TO_COMMAND_PROMPT

        messages = [
            {"role": "system", "content": THOUGHT_TO_COMMAND_PROMPT},
            {"role": "user", "content": thought},
        ]
        content, interrupted = self._stream_completion(messages, "Command to Execute", "cyan", 'Extracting command...')
        if interrupted:
            return None
        return content

    def include_command(self, thought):
        """
        Ask the model whether *thought* contains executable command-line code.

        Returns:
            True if the reply contains "yes", otherwise False. This includes replies
            that contain neither "yes" nor "no".
        """
        prompt = f"An AI assitant is helping a penetration tester work on his job. At each iteration, the AI gives the penetration tester some ideas for him to try. At some point, the AI gives the human the following response\n\nResponse: '{thought}'\n\n Determine if there's any command line code that can be executed in the response. \n\n Is there any command line code that can be executed in the response? Respond 'yes' if there is, respond 'no' there's no clear command line code in the response. Give me a clear 'yes' or 'no', don't say additional words."

        completion = self.client.generate(messages=[{"role": "system", "content": "You are a helpful assistant"}] + [{"role": "user", "content": prompt}])
        answer = "".join(completion).lower()

        # "yes" is checked first, matching the original precedence; anything else counts
        # as "no command".
        return "yes" in answer

    def execute_action(self, command, timeout=60):
        """
        Run *command* (split with shlex, no shell) and capture its output.

        A blank or missing command is rejected without touching the history. Otherwise the
        command is appended to the history as a user message before it runs.

        Args:
            command: Command line to execute.
            timeout: Seconds before the process is killed (default 60).

        Returns:
            A dict with string ``"output"`` and ``"error"`` entries.
        """
        if not command or not command.strip():
            return {"output": "", "error": "No command to execute"}

        history = self.load_history()
        history.append({"role": "user", "content": command})
        self.save_history(history)

        try:
            # Split into an argv list so the command is not interpreted by a shell.
            args = shlex.split(command)
            result = subprocess.run(args, capture_output=True, text=True, check=True, timeout=timeout)
            return {"output": result.stdout, "error": ""}
        except subprocess.TimeoutExpired:
            return {"output": "", "error": f"Command timed out after {timeout} seconds"}
        except subprocess.CalledProcessError as e:
            # Non-zero exit: report whatever the process wrote to each stream.
            return {"output": e.stdout or "", "error": e.stderr or ""}
        except Exception as e:
            # Any other failure (unparsable quoting, missing binary, ...) is reported back
            # to the model as the command's error instead of crashing the session.
            return {"output": "", "error": str(e)}

    def save_execution_result(self, r):
        """
        Append execution result *r* to the history as a JSON-encoded system message.

        Errors longer than 1000 characters are shortened to their first and last 500
        characters. Output is stored unchanged. *r* itself is not modified.
        """
        history = self.load_history()
        result = dict(r)
        error = result['error']
        if len(error) > 1000:
            result['error'] = error[:500] + " ... " + error[-500:]
        history.append({"role": "system", "content": json.dumps(result)})
        self.save_history(history)
